Welcome to the Eager Few Shot Object Detection Colab. This colab demonstrates fine-tuning a (TF2-friendly) RetinaNet architecture on very few examples of a novel class after initializing from a pre-trained COCO checkpoint. Training runs in eager mode.
Estimated time to run through this colab (with GPU): < 5 minutes.
!pip install -U --pre tensorflow=="2.2.0"
Collecting tensorflow==2.2.0
Downloading https://files.pythonhosted.org/packages/3d/be/679ce5254a8c8d07470efb4a4c00345fae91f766e64f1c2aece8796d7218/tensorflow-2.2.0-cp36-cp36m-manylinux2010_x86_64.whl (516.2MB)
Requirement already satisfied, skipping upgrade: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.3.3)
Requirement already satisfied, skipping upgrade: numpy<2.0,>=1.16.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.19.5)
Requirement already satisfied, skipping upgrade: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (3.3.0)
Requirement already satisfied, skipping upgrade: scipy==1.4.1; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.4.1)
Collecting tensorflow-estimator<2.3.0,>=2.2.0
Downloading https://files.pythonhosted.org/packages/a4/f5/926ae53d6a226ec0fda5208e0e581cffed895ccc89e36ba76a8e60895b78/tensorflow_estimator-2.2.0-py2.py3-none-any.whl (454kB)
Requirement already satisfied, skipping upgrade: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.6.3)
Requirement already satisfied, skipping upgrade: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.1.0)
Requirement already satisfied, skipping upgrade: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.15.0)
Collecting tensorboard<2.3.0,>=2.2.0
Downloading https://files.pythonhosted.org/packages/1d/74/0a6fcb206dcc72a6da9a62dd81784bfdbff5fedb099982861dc2219014fb/tensorboard-2.2.2-py3-none-any.whl (3.0MB)
Requirement already satisfied, skipping upgrade: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.10.0)
Requirement already satisfied, skipping upgrade: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (2.10.0)
Requirement already satisfied, skipping upgrade: keras-preprocessing>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.1.2)
Requirement already satisfied, skipping upgrade: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.2.0)
Requirement already satisfied, skipping upgrade: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.12.1)
Requirement already satisfied, skipping upgrade: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (1.32.0)
Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (0.36.2)
Requirement already satisfied, skipping upgrade: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==2.2.0) (3.12.4)
Requirement already satisfied, skipping upgrade: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2.23.0)
Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.3.3)
Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (51.1.2)
Requirement already satisfied, skipping upgrade: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.17.2)
Requirement already satisfied, skipping upgrade: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.7.0)
Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.0.1)
Requirement already satisfied, skipping upgrade: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.4.2)
Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2020.12.5)
Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.24.3)
Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.0.4)
Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (2.10)
Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.3.0)
Requirement already satisfied, skipping upgrade: rsa<5,>=3.1.4; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (4.6)
Requirement already satisfied, skipping upgrade: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (4.2.0)
Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.2.8)
Requirement already satisfied, skipping upgrade: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (1.3.0)
Requirement already satisfied, skipping upgrade: typing-extensions>=3.6.4; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.7.4.3)
Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.4.0)
Requirement already satisfied, skipping upgrade: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= "3"->google-auth<2,>=1.6.3->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (0.4.8)
Requirement already satisfied, skipping upgrade: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.3.0,>=2.2.0->tensorflow==2.2.0) (3.1.0)
Installing collected packages: tensorflow-estimator, tensorboard, tensorflow
Found existing installation: tensorflow-estimator 2.4.0
Uninstalling tensorflow-estimator-2.4.0:
Successfully uninstalled tensorflow-estimator-2.4.0
Found existing installation: tensorboard 2.4.0
Uninstalling tensorboard-2.4.0:
Successfully uninstalled tensorboard-2.4.0
Found existing installation: tensorflow 2.4.0
Uninstalling tensorflow-2.4.0:
Successfully uninstalled tensorflow-2.4.0
Successfully installed tensorboard-2.2.2 tensorflow-2.2.0 tensorflow-estimator-2.2.0
import os
import pathlib
# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models
Cloning into 'models'...
remote: Enumerating objects: 2404, done.
remote: Counting objects: 100% (2404/2404), done.
remote: Compressing objects: 100% (1998/1998), done.
remote: Total 2404 (delta 567), reused 1445 (delta 379), pack-reused 0
Receiving objects: 100% (2404/2404), 30.78 MiB | 38.34 MiB/s, done.
Resolving deltas: 100% (567/567), done.
# Install the Object Detection API
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
Processing /content/models/research
Collecting avro-python3
Downloading https://files.pythonhosted.org/packages/3f/84/ef37f882a7d93674d6fe1aa6e99f18cf2f34e9b775952f3d85587c11c92e/avro-python3-1.10.1.tar.gz
Collecting apache-beam
Downloading https://files.pythonhosted.org/packages/ba/f8/afcd9a457d92bd7858df5ece8f8ecae7d7387fbbb6e3438c6cb495278743/apache_beam-2.27.0-cp36-cp36m-manylinux2010_x86_64.whl (9.0MB)
Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (7.0.0)
Requirement already satisfied: lxml in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (4.2.6)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (3.2.2)
Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (0.29.21)
Requirement already satisfied: contextlib2 in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (0.5.5)
Collecting tf-slim
Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (1.15.0)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (2.0.2)
Collecting lvis
Downloading https://files.pythonhosted.org/packages/72/b6/1992240ab48310b5360bfdd1d53163f43bb97d90dc5dc723c67d41c38e78/lvis-0.5.3-py3-none-any.whl
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (1.4.1)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from object-detection==0.1) (1.1.5)
Collecting tf-models-official
Downloading https://files.pythonhosted.org/packages/57/4a/23a08f8fd2747867ee223612e219eeb0d11c36116601d99b55ef3c72e707/tf_models_official-2.4.0-py2.py3-none-any.whl (1.1MB)
Collecting requests<3.0.0,>=2.24.0
Downloading https://files.pythonhosted.org/packages/29/c1/24814557f1d22c56d50280771a17307e6bf87b70727d975fd6b2ce6b014a/requests-2.25.1-py2.py3-none-any.whl (61kB)
Collecting hdfs<3.0.0,>=2.1.0
Downloading https://files.pythonhosted.org/packages/82/39/2c0879b1bcfd1f6ad078eb210d09dbce21072386a3997074ee91e60ddc5a/hdfs-2.5.8.tar.gz (41kB)
Requirement already satisfied: httplib2<0.18.0,>=0.8 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (0.17.4)
Collecting future<1.0.0,>=0.18.2
Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
Requirement already satisfied: pymongo<4.0.0,>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (3.11.2)
Requirement already satisfied: crcmod<2.0,>=1.7 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (1.7)
Requirement already satisfied: numpy<2,>=1.14.3 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (1.19.5)
Requirement already satisfied: python-dateutil<3,>=2.8.0 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (2.8.1)
Requirement already satisfied: oauth2client<5,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (4.1.3)
Collecting pyarrow<3.0.0,>=0.15.1
Downloading https://files.pythonhosted.org/packages/d7/e1/27958a70848f8f7089bff8d6ebe42519daf01f976d28b481e1bfd52c8097/pyarrow-2.0.0-cp36-cp36m-manylinux2014_x86_64.whl (17.7MB)
Requirement already satisfied: typing-extensions<3.8.0,>=3.7.0 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (3.7.4.3)
Requirement already satisfied: protobuf<4,>=3.12.2 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (3.12.4)
Collecting dill<0.3.2,>=0.3.1.1
Downloading https://files.pythonhosted.org/packages/c7/11/345f3173809cea7f1a193bfbf02403fff250a3360e0e118a1630985e547d/dill-0.3.1.1.tar.gz (151kB)
Requirement already satisfied: grpcio<2,>=1.29.0 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (1.32.0)
Collecting mock<3.0.0,>=1.0.1
Downloading https://files.pythonhosted.org/packages/e6/35/f187bdf23be87092bd0f1200d43d23076cee4d0dec109f195173fd3ebc79/mock-2.0.0-py2.py3-none-any.whl (56kB)
Collecting fastavro<2,>=0.21.4
Downloading https://files.pythonhosted.org/packages/09/03/6f57c4f4fe3e1bb02af78f28dae440f71961296c99dbbc9d5972ead53924/fastavro-1.2.4-cp36-cp36m-manylinux2014_x86_64.whl (2.0MB)
Requirement already satisfied: pytz>=2018.3 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (2018.9)
Requirement already satisfied: pydot<2,>=1.2.0 in /usr/local/lib/python3.6/dist-packages (from apache-beam->object-detection==0.1) (1.3.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->object-detection==0.1) (0.10.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->object-detection==0.1) (2.4.7)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->object-detection==0.1) (1.3.1)
Requirement already satisfied: absl-py>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from tf-slim->object-detection==0.1) (0.10.0)
Requirement already satisfied: setuptools>=18.0 in /usr/local/lib/python3.6/dist-packages (from pycocotools->object-detection==0.1) (51.1.2)
Requirement already satisfied: opencv-python>=4.1.0.25 in /usr/local/lib/python3.6/dist-packages (from lvis->object-detection==0.1) (4.1.2.30)
Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (5.4.8)
Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (4.0.1)
Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (0.4.0)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (0.8)
Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (1.21.0)
Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (1.7.12)
Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (0.8.3)
Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (1.5.10)
Collecting sentencepiece
Downloading https://files.pythonhosted.org/packages/14/67/e42bd1181472c95c8cda79305df848264f2a7f62740995a46945d9797b67/sentencepiece-0.1.95-cp36-cp36m-manylinux2014_x86_64.whl (1.2MB)
Collecting pyyaml>=5.1
Downloading https://files.pythonhosted.org/packages/8a/f2/d2567a720055e0dc09254edd0db810cfdaae01b922fb5bfd6178631b689a/PyYAML-5.4-cp36-cp36m-manylinux1_x86_64.whl (640kB)
Collecting tensorflow>=2.4.0
Downloading https://files.pythonhosted.org/packages/7a/ce/e76c4e3d2c245f4f20eff1bf9cbcc602109448142881e1f946ba2d7327bb/tensorflow-2.4.0-cp36-cp36m-manylinux2010_x86_64.whl (394.7MB)
Collecting opencv-python-headless
Downloading https://files.pythonhosted.org/packages/96/fc/4da675cc522a749ebbcf85c5a63fba844b2d44c87e6f24e3fdb147df3270/opencv_python_headless-4.5.1.48-cp36-cp36m-manylinux2014_x86_64.whl (37.6MB)
Collecting py-cpuinfo>=3.3.0
Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)
Collecting tensorflow-model-optimization>=0.4.1
Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
Collecting seqeval
Downloading https://files.pythonhosted.org/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43kB)
Requirement already satisfied: tensorflow-hub>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-official->object-detection==0.1) (0.11.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (1.24.3)
Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (2020.12.5)
Requirement already satisfied: docopt in /usr/local/lib/python3.6/dist-packages (from hdfs<3.0.0,>=2.1.0->apache-beam->object-detection==0.1) (0.6.2)
Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client<5,>=2.0.1->apache-beam->object-detection==0.1) (0.4.8)
Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client<5,>=2.0.1->apache-beam->object-detection==0.1) (0.2.8)
Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client<5,>=2.0.1->apache-beam->object-detection==0.1) (4.6)
Collecting pbr>=0.11
Downloading https://files.pythonhosted.org/packages/fb/48/69046506f6ac61c1eaa9a0d42d22d54673b69e176d30ca98e3f61513e980/pbr-5.5.1-py2.py3-none-any.whl (106kB)
Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (1.1.0)
Requirement already satisfied: importlib-resources; python_version < "3.9" in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (4.1.1)
Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (20.3.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (4.41.1)
Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (2.3)
Requirement already satisfied: dm-tree in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (0.1.5)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets->tf-models-official->object-detection==0.1) (0.26.0)
Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-official->object-detection==0.1) (1.0.3)
Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-official->object-detection==0.1) (0.4.1)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-official->object-detection==0.1) (0.0.4)
Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-official->object-detection==0.1) (3.0.1)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-official->object-detection==0.1) (1.17.2)
Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-official->object-detection==0.1) (2.7.1)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-official->object-detection==0.1) (4.0.1)
Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (0.36.2)
Collecting tensorboard~=2.4
Downloading https://files.pythonhosted.org/packages/64/21/eebd23060763fedeefb78bc2b286e00fa1d8abda6f70efa2ee08c28af0d4/tensorboard-2.4.1-py3-none-any.whl (10.6MB)
Requirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.12)
Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (3.3.0)
Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.1.2)
Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (0.2.0)
Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (0.3.3)
Collecting tensorflow-estimator<2.5.0,>=2.4.0rc0
Downloading https://files.pythonhosted.org/packages/74/7e/622d9849abf3afb81e482ffc170758742e392ee129ce1540611199a59237/tensorflow_estimator-2.4.0-py2.py3-none-any.whl (462kB)
Requirement already satisfied: h5py~=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (2.10.0)
Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.6.3)
Requirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.12.1)
Requirement already satisfied: scikit-learn>=0.21.3 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-official->object-detection==0.1) (0.22.2.post1)
Requirement already satisfied: zipp>=0.4; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from importlib-resources; python_version < "3.9"->tensorflow-datasets->tf-models-official->object-detection==0.1) (3.4.0)
Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets->tf-models-official->object-detection==0.1) (1.52.0)
Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-official->object-detection==0.1) (1.16.0)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-official->object-detection==0.1) (4.2.0)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-official->object-detection==0.1) (1.3)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.0.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (3.3.3)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.7.0)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (0.4.2)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn>=0.21.3->seqeval->tf-models-official->object-detection==0.1) (1.0.0)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (3.3.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (1.3.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow>=2.4.0->tf-models-official->object-detection==0.1) (3.1.0)
Building wheels for collected packages: object-detection, avro-python3, hdfs, future, dill, py-cpuinfo, seqeval
Building wheel for object-detection (setup.py): started
Building wheel for object-detection (setup.py): finished with status 'done'
Created wheel for object-detection: filename=object_detection-0.1-cp36-none-any.whl size=1613291 sha256=1e2e11428ceb7a471d2bc32a227c282a0b008a9dcf267214f54d387c29f08ab4
Stored in directory: /tmp/pip-ephem-wheel-cache-zufywh56/wheels/94/49/4b/39b051683087a22ef7e80ec52152a27249d1a644ccf4e442ea
Building wheel for avro-python3 (setup.py): started
Building wheel for avro-python3 (setup.py): finished with status 'done'
Created wheel for avro-python3: filename=avro_python3-1.10.1-cp36-none-any.whl size=43734 sha256=12fc3c3f2e6ee694641ae3ae2c1d79d211a76e41434882409844e857308d45e6
Stored in directory: /root/.cache/pip/wheels/65/fe/90/20d6d6d97223d80d20cb390be636619c536edab5658c12bdba
Building wheel for hdfs (setup.py): started
Building wheel for hdfs (setup.py): finished with status 'done'
Created wheel for hdfs: filename=hdfs-2.5.8-cp36-none-any.whl size=33213 sha256=8930f22d93eb0495a6e217d0aad7b92a64cc772d40fc3ee3d6026170b0d66f5e
Stored in directory: /root/.cache/pip/wheels/fe/a7/05/23e3699975fc20f8a30e00ac1e515ab8c61168e982abe4ce70
Building wheel for future (setup.py): started
Building wheel for future (setup.py): finished with status 'done'
Created wheel for future: filename=future-0.18.2-cp36-none-any.whl size=491057 sha256=2d35165d7d55c07b38ff831564af74b3b9ec09d633e4d876ab1f9ac7ff2ead93
Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e
Building wheel for dill (setup.py): started
Building wheel for dill (setup.py): finished with status 'done'
Created wheel for dill: filename=dill-0.3.1.1-cp36-none-any.whl size=78533 sha256=2623ff9894532c2a2aa7e18a3f9f10415b8e2f7ba65f87cdbd853a084f734cc1
Stored in directory: /root/.cache/pip/wheels/59/b1/91/f02e76c732915c4015ab4010f3015469866c1eb9b14058d8e7
Building wheel for py-cpuinfo (setup.py): started
Building wheel for py-cpuinfo (setup.py): finished with status 'done'
Created wheel for py-cpuinfo: filename=py_cpuinfo-7.0.0-cp36-none-any.whl size=20072 sha256=11978d123310182c9e50300129b9dac43c8555088af11ab3ef8acbe06d3ff4a3
Stored in directory: /root/.cache/pip/wheels/f1/93/7b/127daf0c3a5a49feb2fecd468d508067c733fba5192f726ad1
Building wheel for seqeval (setup.py): started
Building wheel for seqeval (setup.py): finished with status 'done'
Created wheel for seqeval: filename=seqeval-1.2.2-cp36-none-any.whl size=16171 sha256=04cae8b3389180971fe76a260dfba317e748aa0af22292fd56cf324045a7e122
Stored in directory: /root/.cache/pip/wheels/52/df/1b/45d75646c37428f7e626214704a0e35bd3cfc32eda37e59e5f
Successfully built object-detection avro-python3 hdfs future dill py-cpuinfo seqeval
Installing collected packages: avro-python3, requests, hdfs, future, pyarrow, dill, pbr, mock, fastavro, apache-beam, tf-slim, lvis, sentencepiece, pyyaml, tensorboard, tensorflow-estimator, tensorflow, opencv-python-headless, py-cpuinfo, tensorflow-model-optimization, seqeval, tf-models-official, object-detection
Found existing installation: requests 2.23.0
Uninstalling requests-2.23.0:
Successfully uninstalled requests-2.23.0
Found existing installation: future 0.16.0
Uninstalling future-0.16.0:
Successfully uninstalled future-0.16.0
Found existing installation: pyarrow 0.14.1
Uninstalling pyarrow-0.14.1:
Successfully uninstalled pyarrow-0.14.1
Found existing installation: dill 0.3.3
Uninstalling dill-0.3.3:
Successfully uninstalled dill-0.3.3
Found existing installation: PyYAML 3.13
Uninstalling PyYAML-3.13:
Successfully uninstalled PyYAML-3.13
Found existing installation: tensorboard 2.2.2
Uninstalling tensorboard-2.2.2:
Successfully uninstalled tensorboard-2.2.2
Found existing installation: tensorflow-estimator 2.2.0
Uninstalling tensorflow-estimator-2.2.0:
Successfully uninstalled tensorflow-estimator-2.2.0
Found existing installation: tensorflow 2.2.0
Uninstalling tensorflow-2.2.0:
Successfully uninstalled tensorflow-2.2.0
Successfully installed apache-beam-2.27.0 avro-python3-1.10.1 dill-0.3.1.1 fastavro-1.2.4 future-0.18.2 hdfs-2.5.8 lvis-0.5.3 mock-2.0.0 object-detection-0.1 opencv-python-headless-4.5.1.48 pbr-5.5.1 py-cpuinfo-7.0.0 pyarrow-2.0.0 pyyaml-5.4 requests-2.25.1 sentencepiece-0.1.95 seqeval-1.2.2 tensorboard-2.4.1 tensorflow-2.4.0 tensorflow-estimator-2.4.0 tensorflow-model-optimization-0.5.0 tf-models-official-2.4.0 tf-slim-1.1.0
ERROR: multiprocess 0.70.11.1 has requirement dill>=0.3.3, but you'll have dill 0.3.1.1 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: apache-beam 2.27.0 has requirement avro-python3!=1.9.2,<1.10.0,>=1.8.1, but you'll have avro-python3 1.10.1 which is incompatible.
import matplotlib
import matplotlib.pyplot as plt
import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder
%matplotlib inline
def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: a file path.

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)
def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None):
  """Wrapper function to visualize detections.

  Args:
    image_np: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None. If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    figsize: size for the figure.
    image_name: a name for the image file.
  """
  image_np_with_annotations = image_np.copy()
  viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_annotations,
      boxes,
      classes,
      scores,
      category_index,
      use_normalized_coordinates=True,
      min_score_thresh=0.8)
  if image_name:
    plt.imsave(image_name, image_np_with_annotations)
  else:
    plt.imshow(image_np_with_annotations)
# Load images and visualize
train_image_dir = 'models/research/object_detection/test_images/ducky/train/'
train_images_np = []
for i in range(1, 6):
  image_path = os.path.join(train_image_dir, 'robertducky' + str(i) + '.jpg')
  train_images_np.append(load_image_into_numpy_array(image_path))
plt.rcParams['axes.grid'] = False
plt.rcParams['xtick.labelsize'] = False
plt.rcParams['ytick.labelsize'] = False
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['figure.figsize'] = [14, 7]
for idx, train_image_np in enumerate(train_images_np):
  plt.subplot(2, 3, idx+1)
  plt.imshow(train_image_np)
plt.show()
In this cell you will annotate the rubber duckies --- draw a box around the rubber ducky in each image; click next image to go to the next image and submit when there are no more images.
If you'd like to skip the manual annotation step, we totally understand. In this case, simply skip this cell and run the next cell instead, where we've prepopulated the groundtruth with pre-annotated bounding boxes.
gt_boxes = []
colab_utils.annotate(train_images_np, box_storage_pointer=gt_boxes)
Run this cell only if you didn't annotate anything above and would prefer to just use our preannotated boxes. Don't forget to uncomment.
# gt_boxes = [
# np.array([[0.436, 0.591, 0.629, 0.712]], dtype=np.float32),
# np.array([[0.539, 0.583, 0.73, 0.71]], dtype=np.float32),
# np.array([[0.464, 0.414, 0.626, 0.548]], dtype=np.float32),
# np.array([[0.313, 0.308, 0.648, 0.526]], dtype=np.float32),
# np.array([[0.256, 0.444, 0.484, 0.629]], dtype=np.float32)
# ]
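The pre-annotated boxes above are stored in normalized coordinates, which the Object Detection API conventionally orders as `[ymin, xmin, ymax, xmax]` with values in `[0, 1]`. As a minimal sketch of what those numbers mean, the snippet below converts one box to pixel coordinates. The helper name `to_pixel_box` and the 640x480 image size are assumptions made for illustration only; they are not part of the colab or the API.

```python
def to_pixel_box(box, image_height, image_width):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to integer pixels."""
    ymin, xmin, ymax, xmax = box
    return (round(ymin * image_height), round(xmin * image_width),
            round(ymax * image_height), round(xmax * image_width))


# First pre-annotated box from the cell above, applied to a hypothetical
# image that is 480 pixels high and 640 pixels wide.
pixel_box = to_pixel_box([0.436, 0.591, 0.629, 0.712],
                         image_height=480, image_width=640)
print(pixel_box)  # (209, 378, 302, 456)
```

Because the coordinates are normalized, the same box remains valid after the image is resized, which is convenient given that the model resizes its inputs.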
Below we add the class annotations. For simplicity, we assume a single class in this colab, though it should be straightforward to extend this to handle multiple classes. We also convert everything to the format that the training loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.).
# By convention, our non-background classes start counting at 1. Given
# that we will be predicting just one class, we will therefore assign it a
# `class id` of 1.
duck_class_id = 1
num_classes = 1
category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}
# Convert class labels to one-hot; convert everything to tensors.
# The `label_id_offset` here shifts all classes by a certain number of indices;
# we do this here so that the model receives one-hot labels where non-background
# classes start counting at the zeroth index. This is ordinarily just handled
# automatically in our training binaries, but we need to reproduce it here.
label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(
    train_images_np, gt_boxes):
  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
      train_image_np, dtype=tf.float32), axis=0))
  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
  zero_indexed_groundtruth_classes = tf.convert_to_tensor(
      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
  gt_classes_one_hot_tensors.append(tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
Done prepping data.
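The cell above handles a single class, so every one-hot label is simply `[1.0]`. To make the role of `label_id_offset` more visible, here is a small sketch of the same shift-then-one-hot step for a hypothetical three-class label map. It uses NumPy instead of `tf.one_hot` so that it runs without TensorFlow; the helper name `one_hot_labels` and the class ids are illustrative assumptions, not part of the colab.

```python
import numpy as np


def one_hot_labels(class_ids, num_classes, label_id_offset=1):
    """Map 1-based class ids to one-hot rows whose first column is class id 1."""
    class_ids = np.asarray(class_ids, dtype=np.int32)
    zero_indexed = class_ids - label_id_offset  # class id 1 -> column 0
    return np.eye(num_classes, dtype=np.float32)[zero_indexed]


# Two boxes in one image: the first has class id 1, the second class id 3.
labels = one_hot_labels([1, 3], num_classes=3)
print(labels)
# [[1. 0. 0.]
#  [0. 0. 1.]]
```

The resulting array has shape `[N, num_classes]`, matching what the training step expects for each image.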
dummy_scores = np.array([1.0], dtype=np.float32) # give boxes a score of 100%
plt.figure(figsize=(30, 15))
for idx in range(5):
  plt.subplot(2, 3, idx+1)
  plot_detections(
      train_images_np[idx],
      gt_boxes[idx],
      np.ones(shape=[gt_boxes[idx].shape[0]], dtype=np.int32),
      dummy_scores, category_index)
plt.show()
In this cell we build a single-stage detection architecture (RetinaNet) and restore all weights except the classification layer at the top, which is automatically randomly initialized.
For simplicity, we have hardcoded a number of things in this colab for the specific RetinaNet architecture at hand (including assuming that the image size will always be 640x640). However, it is not difficult to generalize to other model configurations.
# Download the checkpoint and put it into models/research/object_detection/test_data/
!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/
--2021-01-20 04:55:44-- http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 172.217.13.240, 2607:f8b0:4004:809::2010
Connecting to download.tensorflow.org (download.tensorflow.org)|172.217.13.240|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 244817203 (233M) [application/x-tar]
Saving to: ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’
ssd_resnet50_v1_fpn 100%[===================>] 233.48M 227MB/s in 1.0s
2021-01-20 04:55:45 (227 MB/s) - ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’ saved [244817203/244817203]
tf.keras.backend.clear_session()
print('Building model and restoring weights for fine-tuning...', flush=True)
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'
# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
model_config=model_config, is_training=True)
# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression. We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
_base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
# _prediction_heads=detection_model._box_predictor._prediction_heads,
# (i.e., the classification head that we *will not* restore)
_box_prediction_head=detection_model._box_predictor._box_prediction_head,
)
fake_model = tf.compat.v2.train.Checkpoint(
_feature_extractor=detection_model._feature_extractor,
_box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()
# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
Building model and restoring weights for fine-tuning...
Weights restored!
tf.keras.backend.set_learning_phase(True)
# These parameters can be tuned; since our training set has 5 images
# it doesn't make sense to have a much larger batch size, though we could
# fit more examples in memory if we wanted to.
batch_size = 4
learning_rate = 0.01
num_batches = 100
# Select variables in top layers to fine-tune.
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
  if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
    to_fine_tune.append(var)
# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
  """Get a tf.function for training step."""

  # Use tf.function for a bit of speed.
  # Comment out the tf.function decorator if you want the inside of the
  # function to run eagerly.
  @tf.function
  def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
    """A single training iteration.

    Args:
      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
        Note that the height and width can vary across images, as they are
        reshaped within this function to be 640x640.
      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
        tf.float32 representing groundtruth boxes for each image in the batch.
      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
        with type tf.float32 representing one-hot groundtruth classes for each
        image in the batch.

    Returns:
      A scalar tensor representing the total loss for the input batch.
    """
    shapes = tf.constant(batch_size * [[640, 640, 3]], dtype=tf.int32)
    model.provide_groundtruth(
        groundtruth_boxes_list=groundtruth_boxes_list,
        groundtruth_classes_list=groundtruth_classes_list)
    with tf.GradientTape() as tape:
      # Use the `model` argument rather than the global `detection_model`.
      preprocessed_images = tf.concat(
          [model.preprocess(image_tensor)[0]
           for image_tensor in image_tensors], axis=0)
      prediction_dict = model.predict(preprocessed_images, shapes)
      losses_dict = model.loss(prediction_dict, shapes)
      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
    gradients = tape.gradient(total_loss, vars_to_fine_tune)
    optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
    return total_loss

  return train_step_fn
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(
detection_model, optimizer, to_fine_tune)
print('Start fine-tuning!', flush=True)
for idx in range(num_batches):
  # Grab keys for a random subset of examples
  all_keys = list(range(len(train_images_np)))
  random.shuffle(all_keys)
  example_keys = all_keys[:batch_size]

  # Note that we do not do data augmentation in this demo. If you want a
  # fun exercise, we recommend experimenting with random horizontal flipping
  # and random cropping :)
  gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
  gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
  image_tensors = [train_image_tensors[key] for key in example_keys]

  # Training step (forward pass + backwards pass)
  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

  if idx % 10 == 0:
    print('batch ' + str(idx) + ' of ' + str(num_batches)
          + ', loss=' + str(total_loss.numpy()), flush=True)

print('Done fine-tuning!')
Start fine-tuning!
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
batch 0 of 100, loss=1.1604084
batch 10 of 100, loss=0.07527775
batch 20 of 100, loss=0.020109622
batch 30 of 100, loss=0.010939896
batch 40 of 100, loss=0.0066314433
batch 50 of 100, loss=0.005506818
batch 60 of 100, loss=0.0046734633
batch 70 of 100, loss=0.004406075
batch 80 of 100, loss=0.00373389
batch 90 of 100, loss=0.0027542359
Done fine-tuning!
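The training loop above skips data augmentation and suggests random horizontal flipping as an exercise. The sketch below shows one way the flip itself could look: mirroring the image along its width also requires mirroring the x-coordinates of every box. It assumes boxes are normalized `[ymin, xmin, ymax, xmax]` and works on NumPy arrays rather than tensors; the helper name `flip_horizontal` is hypothetical, and this is not the augmentation used by the Object Detection API's own training binaries.

```python
import numpy as np


def flip_horizontal(image, boxes):
    """Mirror an image left-to-right together with its normalized boxes.

    image: array of shape [height, width, channels].
    boxes: array of shape [N, 4], rows are [ymin, xmin, ymax, xmax] in [0, 1].
    """
    flipped_image = image[:, ::-1, :]
    ymin, xmin, ymax, xmax = np.split(boxes, 4, axis=1)
    # The old right edge becomes the new left edge, and vice versa.
    flipped_boxes = np.concatenate(
        [ymin, 1.0 - xmax, ymax, 1.0 - xmin], axis=1)
    return flipped_image, flipped_boxes


# Tiny synthetic example: a 2x4 RGB image with one box on its left half.
image = np.arange(2 * 4 * 3, dtype=np.uint8).reshape(2, 4, 3)
boxes = np.array([[0.25, 0.0, 0.75, 0.5]], dtype=np.float32)
flipped_image, flipped_boxes = flip_horizontal(image, boxes)
print(flipped_boxes)  # the box now covers the right half: x from 0.5 to 1.0
```

To randomize it, one could apply the flip to each sampled example with probability 0.5 before converting the arrays to tensors.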
test_image_dir = 'models/research/object_detection/test_images/ducky/test/'
test_images_np = []
for i in range(1, 50):
  image_path = os.path.join(test_image_dir, 'out' + str(i) + '.jpg')
  test_images_np.append(np.expand_dims(
      load_image_into_numpy_array(image_path), axis=0))
# Again, comment out this decorator if you want to run inference eagerly
@tf.function
def detect(input_tensor):
  """Run detection on an input image.

  Args:
    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
      Note that height and width can be anything since the image will be
      immediately resized according to the needs of the model within this
      function.

  Returns:
    A dict containing 3 Tensors (`detection_boxes`, `detection_classes`,
      and `detection_scores`).
  """
  preprocessed_image, shapes = detection_model.preprocess(input_tensor)
  prediction_dict = detection_model.predict(preprocessed_image, shapes)
  return detection_model.postprocess(prediction_dict, shapes)
# Note that the first frame will trigger tracing of the tf.function, which will
# take some time, after which inference should be fast.
label_id_offset = 1
for i in range(len(test_images_np)):
  input_tensor = tf.convert_to_tensor(test_images_np[i], dtype=tf.float32)
  detections = detect(input_tensor)

  plot_detections(
      test_images_np[i][0],
      detections['detection_boxes'][0].numpy(),
      detections['detection_classes'][0].numpy().astype(np.uint32)
      + label_id_offset,
      detections['detection_scores'][0].numpy(),
      category_index, figsize=(15, 20), image_name="gif_frame_" + ('%02d' % i) + ".jpg")
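The plotting helper used above passes `min_score_thresh=0.8`, so only confident detections are drawn. If you want the surviving boxes as arrays rather than as pixels in an image, the same filtering can be done by hand. The sketch below is an illustration under assumptions: `filter_detections` is a hypothetical helper, and the detection values are made up rather than taken from the model's output.

```python
import numpy as np


def filter_detections(boxes, classes, scores, min_score=0.8):
    """Keep only the detections whose score is above `min_score`."""
    keep = scores > min_score
    return boxes[keep], classes[keep], scores[keep]


# Made-up detections for one image: three boxes, one of them low-confidence.
boxes = np.array([[0.1, 0.1, 0.5, 0.5],
                  [0.2, 0.2, 0.6, 0.6],
                  [0.3, 0.3, 0.9, 0.9]], dtype=np.float32)
classes = np.array([1, 1, 1], dtype=np.int32)
scores = np.array([0.95, 0.40, 0.85], dtype=np.float32)

kept_boxes, kept_classes, kept_scores = filter_detections(boxes, classes, scores)
print(len(kept_boxes))  # 2
```

In the loop above, the inputs would be `detections['detection_boxes'][0].numpy()` and its sibling entries for classes and scores.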
imageio.plugins.freeimage.download()
anim_file = 'duckies_test.gif'
filenames = glob.glob('gif_frame_*.jpg')
filenames = sorted(filenames)
last = -1
images = []
for filename in filenames:
  image = imageio.imread(filename)
  images.append(image)
imageio.mimsave(anim_file, images, 'GIF-FI', fps=5)
display(IPyImage(open(anim_file, 'rb').read()))
Imageio: 'libfreeimage-3.16.0-linux64.so' was not found on your computer; downloading it now.
Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/freeimage/libfreeimage-3.16.0-linux64.so (4.6 MB)
Downloading: 8192/4830080 bytes (0.2%)
3670016/4830080 bytes (76.0%)
4830080/4830080 bytes (100.0%)
Done
File saved as /root/.imageio/freeimage/libfreeimage-3.16.0-linux64.so.
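The GIF cell relies on `sorted(filenames)` returning the frames in playback order, which works because the frame names were written with a zero-padded index (`'%02d' % i`). The short sketch below illustrates why the padding matters, using a handful of example indices; no files are read or written.

```python
# Frame names without and with zero padding, for a few example indices.
indices = (0, 1, 2, 10)
unpadded = ['gif_frame_%d.jpg' % i for i in indices]
padded = ['gif_frame_%02d.jpg' % i for i in indices]

# Plain string sorting compares character by character, so "10" sorts
# before "2" unless every index has the same number of digits.
print(sorted(unpadded))
# ['gif_frame_0.jpg', 'gif_frame_1.jpg', 'gif_frame_10.jpg', 'gif_frame_2.jpg']
print(sorted(padded))
# ['gif_frame_00.jpg', 'gif_frame_01.jpg', 'gif_frame_02.jpg', 'gif_frame_10.jpg']
```

Two digits are enough for the 49 test frames used here; a longer sequence would need a wider pad.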